import csv
import os
import sys
import time
#from sklearn.neighbors import NearestNeighbors

from util.graphdb_base import GraphDBBase

class DopamineNodeClassification(GraphDBBase):
    """Build, configure and train a Neo4j GDS node-classification pipeline.

    Every database call goes through ``self._driver``, which is expected to
    be provided by ``GraphDBBase`` (not visible in this file).
    """

    # Centrality steps used for the "centrality" feature set, as
    # (GDS procedure name, node property the step mutates) pairs. The mutated
    # property names are also the ones handed to selectFeatures.
    # HITS ('alpha.hits') and CELF ('beta.influenceMaximization.celf') steps
    # were previously defined here but never run; CELF was noted as taking
    # too long to run.
    _CENTRALITY_STEPS = (
        ("degree", "nodeDegree"),
        ("betweenness", "nodeBetweenness"),
        ("eigenvector", "nodeEigenvector"),
        ("pageRank", "nodePageRank"),
        ("articleRank", "nodeArticleRank"),
    )

    # Adds one relationship-weighted node-property step to the pipeline.
    _ADD_WEIGHTED_PROPERTY_QUERY = """
        CALL gds.beta.pipeline.nodeClassification.addNodeProperty(
            $pipelineName,
            $procedure,
            {
                mutateProperty: $mutateProperty,
                relationshipWeightProperty: $weightProperty
            }
        )
    """

    # Node-embedding step used for every feature set other than "centrality".
    _ADD_FASTRP_QUERY = """
        CALL gds.beta.pipeline.nodeClassification.addNodeProperty(
            $pipelineName,
            'fastRP',
            {
                mutateProperty: 'nodeFastRP',
                embeddingDimension: 128,
                relationshipWeightProperty: $weightProperty
            }
        )
    """

    _SELECT_FEATURES_QUERY = """
        CALL gds.beta.pipeline.nodeClassification.selectFeatures(
            $pipelineName,
            $featureProperties
        )
    """

    def __init__(self, argv):
        """Forward the command-line arguments to ``GraphDBBase``."""
        super().__init__(command=__file__, argv=argv)

    def _run_in_transaction(self, statements):
        """Run ``statements`` in one explicit transaction and commit it.

        Args:
            statements: iterable of ``(cypher_query, parameters)`` pairs,
                executed in order.

        Returns:
            The records produced by the last statement, as a list (empty if
            there were no statements or the last one returned no rows).

        Raises:
            Exception: whatever the driver raised. The error is printed and
                re-raised so the caller never continues after a failed step.
        """
        records = []
        with self._driver.session() as session:
            try:
                # Leaving the transaction block because of an exception rolls
                # the transaction back instead of committing it.
                with session.begin_transaction() as tx:
                    for query, parameters in statements:
                        # Materialising the result surfaces server-side
                        # errors at the statement that caused them.
                        records = list(tx.run(query, parameters))
                    tx.commit()
            except Exception as e:
                print("transaction failed:", e)
                raise
        return records

    def initialize_pipeline(self, pipeline_name):
        """Drop any pipeline named ``pipeline_name`` and create an empty one."""
        print("initializing pipeline")
        start = time.time()
        parameters = {"pipelineName": pipeline_name}
        self._run_in_transaction([
            # The `false` argument is GDS's failIfMissing flag, so dropping a
            # pipeline that does not exist yet is not an error.
            ("CALL gds.beta.pipeline.drop($pipelineName, false)", parameters),
            ("CALL gds.beta.pipeline.nodeClassification.create($pipelineName)", parameters),
        ])
        print("time to initialize pipeline:", time.time() - start)

    def add_node_properties(self, pipeline_name, weight_property, features):
        """Add node-property steps to the pipeline and select them as features.

        Args:
            pipeline_name: name of the pipeline to extend.
            weight_property: relationship property used as the weight by
                every step.
            features: ``"centrality"`` adds the degree, betweenness,
                eigenvector, PageRank and ArticleRank steps; any other value
                adds a single 128-dimensional FastRP embedding step.
        """
        print("adding node properties to pipeline")
        start = time.time()

        if features == "centrality":
            statements = [
                (
                    self._ADD_WEIGHTED_PROPERTY_QUERY,
                    {
                        "pipelineName": pipeline_name,
                        "procedure": procedure,
                        "mutateProperty": mutate_property,
                        "weightProperty": weight_property,
                    },
                )
                for procedure, mutate_property in self._CENTRALITY_STEPS
            ]
            feature_properties = [
                mutate_property for _, mutate_property in self._CENTRALITY_STEPS
            ]
        else:
            statements = [
                (
                    self._ADD_FASTRP_QUERY,
                    {"pipelineName": pipeline_name, "weightProperty": weight_property},
                )
            ]
            feature_properties = ["nodeFastRP"]

        # The selected features are exactly the properties mutated above.
        statements.append(
            (
                self._SELECT_FEATURES_QUERY,
                {"pipelineName": pipeline_name, "featureProperties": feature_properties},
            )
        )
        self._run_in_transaction(statements)
        print("time to add node properties to pipeline:", time.time() - start)

    def add_split_and_model_configuration(self, pipeline_name):
        """Configure the train/test split and add a random-forest candidate.

        The split uses a 0.3 test fraction with 5 validation folds; the model
        candidate is a random forest with 100 decision trees.
        """
        print("configurating model pipeline")
        start = time.time()

        split_configuration = """
            CALL gds.beta.pipeline.nodeClassification.configureSplit($pipelineName,
            {
                testFraction: 0.3,
                validationFolds: 5
            })
        """

        # TODO: consider adding auto-tuning.
        model_configuration = """
            CALL gds.alpha.pipeline.nodeClassification.addRandomForest(
                $pipelineName,
                {
                    numberOfDecisionTrees: 100
                })
        """

        parameters = {"pipelineName": pipeline_name}
        self._run_in_transaction([
            (split_configuration, parameters),
            (model_configuration, parameters),
        ])
        print("time to configure the model:", time.time() - start)

    def project_graph(self, projection_name, target_label, relationship, weight_property):
        """(Re)create the in-memory graph projection used for training.

        Any existing projection called ``projection_name`` is dropped first.
        The new projection contains nodes labelled ``target_label``,
        'DopamineASD' or 'DopamineDD', the given relationship type, the node
        properties 'class', 'dopamineGeneDosageVector' and 'dopamineGoVector',
        and ``weight_property`` as relationship property.
        """
        print("creating graph projection")
        start = time.time()

        # `false` = do not fail when the projection does not exist yet.
        drop = """CALL gds.graph.drop($projectionName, false)"""

        create = """
            CALL gds.graph.project(
                $projectionName,
                [$targetLabel, 'DopamineASD', 'DopamineDD'],
                $relationship,
                {nodeProperties: ['class', 'dopamineGeneDosageVector', 'dopamineGoVector'],
                relationshipProperties: $weightProperty
                }
                )
        """

        self._run_in_transaction([
            (drop, {"projectionName": projection_name}),
            (
                create,
                {
                    "projectionName": projection_name,
                    "targetLabel": target_label,
                    "relationship": relationship,
                    "weightProperty": weight_property,
                },
            ),
        ])
        print("time to create the graph projection:", time.time() - start)

    def train_model(self, projection_name, pipeline_name, model_name, target_label, seed):
        """Train the pipeline on the projection and return the test metrics.

        Any existing model called ``model_name`` is dropped first.

        Args:
            projection_name: name of the graph projection to train on.
            pipeline_name: name of the configured pipeline.
            model_name: name under which the trained model is stored.
            target_label: node label whose nodes are used as training targets.
            seed: random seed for the training run; also stored as ``"run"``.

        Returns:
            A dict with the keys ``run``, ``best_parameters``, ``accuracy``,
            ``f1``, ``precision_asd``, ``precision_dd``, ``recall_asd`` and
            ``recall_dd`` (the ``asd`` values come from class 1, the ``dd``
            values from class 0), or an empty dict when the training query
            returned no rows.
        """
        print("training the model")
        start = time.time()

        drop = """CALL gds.beta.model.drop($modelName, false)"""

        train = """
                    CALL gds.beta.pipeline.nodeClassification.train($projectionName, {
                    pipeline: $pipelineName,
                    targetNodeLabels: [$targetNodeLabels],
                    modelName: $modelName,
                    targetProperty: 'class',
                    randomSeed: $seed,
                    metrics: ['ACCURACY', 'OUT_OF_BAG_ERROR', 'F1_WEIGHTED', 'PRECISION(class=*)','RECALL(class=*)']
                    }) YIELD modelInfo, modelSelectionStats
                    RETURN
                      modelInfo.bestParameters AS winningModel,
                      modelInfo.metrics.ACCURACY.train.avg AS avgTrainScore,
                      modelInfo.metrics.ACCURACY.test AS testScore,
                      modelInfo.metrics.F1_WEIGHTED.test as f1,
                      modelInfo.metrics.PRECISION_class_1.test as precision_asd,
                      modelInfo.metrics.PRECISION_class_0.test as precision_dd,
                      modelInfo.metrics.RECALL_class_1.test as recall_asd,
                      modelInfo.metrics.RECALL_class_0.test as recall_dd
        """

        records = self._run_in_transaction([
            (drop, {"modelName": model_name}),
            (
                train,
                {
                    "projectionName": projection_name,
                    "pipelineName": pipeline_name,
                    "targetNodeLabels": target_label,
                    "modelName": model_name,
                    "seed": seed,
                },
            ),
        ])

        metrics = {}
        if records:
            # If several rows come back, the last one wins.
            record = records[-1]
            metrics = {
                "run": seed,
                "best_parameters": record["winningModel"],
                "accuracy": record["testScore"],
                "f1": record["f1"],
                "precision_asd": record["precision_asd"],
                "precision_dd": record["precision_dd"],
                "recall_asd": record["recall_asd"],
                "recall_dd": record["recall_dd"],
            }

        print("time to train the model:", time.time() - start)
        return metrics

    def write_metrics(self, metrics, csv_file_path, seed):
        """Write one row of metrics to ``csv_file_path``.

        Args:
            metrics: mapping of metric name to value; the keys become the
                header. An empty mapping writes no rows.
            csv_file_path: destination CSV file; missing parent directories
                are created.
            seed: run number. ``seed == 1`` starts a fresh file (truncating
                any existing one); any other value appends.

        The header row is written whenever the file starts out empty, so an
        appended run into a new file still produces a readable CSV.
        """
        start_fresh = seed == 1

        # Must be decided before opening, because opening may create the file.
        needs_header = (
            start_fresh
            or not os.path.exists(csv_file_path)
            or os.path.getsize(csv_file_path) == 0
        )

        parent_dir = os.path.dirname(csv_file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(csv_file_path, mode="w" if start_fresh else "a", newline="") as file:
            if not metrics:
                # Nothing to record; avoid writing blank header/value rows.
                print("no metrics to write for seed", seed)
                return

            writer = csv.writer(file)
            if needs_header:
                writer.writerow(list(metrics.keys()))
            # dict preserves insertion order, so values line up with the header.
            writer.writerow(list(metrics.values()))

if __name__ == '__main__':
    # Runs 100 seeded trainings for every combination of feature set and
    # (target node label, relationship type) pair, writing one CSV per
    # combination.
    dopamine_classification = DopamineNodeClassification(sys.argv[1:])

    pipeline_name = "dopamine_pipeline"
    weight_property = 'weight'
    projection_name = "dopamine_projection"
    results_path = "./node_classification/model_metrics/"

    # Each target node label is trained on its matching relationship type.
    label_relationship_pairs = (
        ('DopamineGeneticSimilarity', 'DOPAMINE_GENETIC_SIMILARITY'),
        ('DopamineGoSimilarity', 'DOPAMINE_GO_SIMILARITY'),
        ('DopamineGeneticGoSimilarity', 'DOPAMINE_GENETIC_GO_SIMILARITY'),
    )

    for feature in ("centrality", "embeddings"):
        for target_label, relationship in label_relationship_pairs:
            print("Using ", target_label, " target node labels ", relationship, " relationships, and " + feature + " features")
            csv_file_path = f"{results_path}{target_label}_{feature}.csv"

            # Seeds 1..100; seed 1 makes write_metrics start a fresh file.
            for seed in range(1, 101):
                dopamine_classification.initialize_pipeline(pipeline_name)
                dopamine_classification.add_node_properties(pipeline_name, weight_property, feature)
                dopamine_classification.add_split_and_model_configuration(pipeline_name)
                dopamine_classification.project_graph(projection_name, target_label, relationship, weight_property)
                metrics = dopamine_classification.train_model(projection_name, pipeline_name, "model_name", target_label, seed)
                dopamine_classification.write_metrics(metrics, csv_file_path, seed)

            

